{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "from fastcore.test import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model_constructor.constructor import *\n",
    "from model_constructor.layers import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# New"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Body(\n",
       "  (layer_0): BasicLayer(\n",
       "    from 64 to 64, 2 blocks, expansion 4.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (sa): SimpleSelfAttention(\n",
       "        (conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_1): BasicLayer(\n",
       "    from 256 to 128, 2 blocks, expansion 4.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_2): BasicLayer(\n",
       "    from 512 to 256, 2 blocks, expansion 4.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_3): BasicLayer(\n",
       "    from 1024 to 512, 2 blocks, expansion 4.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bottleneck body: ResBlock built from ConvBlockBottle, expansion 4, self-attention switched on.\n",
    "body = Body(\n",
    "    ResBlock,\n",
    "    conv_block=ConvBlockBottle,\n",
    "    expansion=4,\n",
    "    sa=True,\n",
    ")\n",
    "body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 2048, 4, 4])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke-test the bottleneck body on a random batch with 64 input channels.\n",
    "xb = torch.randn(16, 64, 32, 32)\n",
    "y = body(xb)\n",
    "# layer_1..layer_3 each downsample by 2 (32 -> 4); the last layer outputs 2048 channels.\n",
    "assert y.shape == (16, 2048, 4, 4), y.shape\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Body(\n",
       "  (layer_0): BasicLayer(\n",
       "    from 64 to 64, 2 blocks, expansion 1.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_1): BasicLayer(\n",
       "    from 64 to 128, 2 blocks, expansion 1.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_2): BasicLayer(\n",
       "    from 128 to 256, 2 blocks, expansion 1.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_3): BasicLayer(\n",
       "    from 256 to 512, 2 blocks, expansion 1.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Basic body: ResBlock with the default conv block and expansion 1.\n",
    "body = Body(\n",
    "    ResBlock,\n",
    "    expansion=1,\n",
    ")\n",
    "body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 512, 4, 4])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke-test the expansion=1 body on a random batch with 64 input channels.\n",
    "xb = torch.randn(16, 64, 32, 32)\n",
    "y = body(xb)\n",
    "# layer_1..layer_3 each downsample by 2 (32 -> 4); with expansion 1 the last layer outputs 512 channels.\n",
    "assert y.shape == (16, 512, 4, 4), y.shape\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Xresnet models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def xresnet18(**kwargs):\n",
    "    \"\"\"Build the xresnet-18 variant of `Net`.\n",
    "\n",
    "    Uses `stem_sizes=[32, 32]`, `ResBlock` blocks in a [2, 2, 2, 2] layout and\n",
    "    expansion 1. Any extra keyword arguments are forwarded to `Net` unchanged.\n",
    "    \"\"\"\n",
    "    preset = dict(stem_sizes=[32, 32], block=ResBlock, blocks=[2, 2, 2, 2], expansion=1)\n",
    "    return Net(**preset, **kwargs)\n",
    "def xresnet34(**kwargs):\n",
    "    \"\"\"Construct an xresnet-34 model.\n",
    "\n",
    "    Uses `stem_sizes=[32, 32]`, `ResBlock` blocks in a [3, 4, 6, 3] layout and\n",
    "    expansion 1. Any extra keyword arguments are forwarded to `Net` unchanged.\n",
    "    \"\"\"\n",
    "    return Net(stem_sizes=[32,32], block=ResBlock, blocks=[3, 4, 6, 3], expansion=1, **kwargs)\n",
    "def xresnet50(**kwargs):\n",
    "    \"\"\"Construct an xresnet-50 model.\n",
    "\n",
    "    Uses `stem_sizes=[32, 32]`, `ResBlock` blocks in a [3, 4, 6, 3] layout with\n",
    "    `ConvBlockBottle` as the conv block and expansion 4. Any extra keyword\n",
    "    arguments are forwarded to `Net` unchanged.\n",
    "    \"\"\"\n",
    "    return Net(stem_sizes=[32,32], block=ResBlock, blocks=[3, 4, 6, 3], conv_block=ConvBlockBottle, expansion=4, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 32, 32, 64]\n",
       "    (conv0): ConvLayer(\n",
       "      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv1): ConvLayer(\n",
       "      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv2): ConvLayer(\n",
       "      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 2 blocks, expansion 1.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 64 to 128, 2 blocks, expansion 1.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 128 to 256, 2 blocks, expansion 1.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 256 to 512, 2 blocks, expansion 1.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xresnet18()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 32, 32, 64]\n",
       "    (conv0): ConvLayer(\n",
       "      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv1): ConvLayer(\n",
       "      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv2): ConvLayer(\n",
       "      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 3 blocks, expansion 1.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 64 to 128, 4 blocks, expansion 1.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 128 to 256, 6 blocks, expansion 1.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_4): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_5): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 256 to 512, 3 blocks, expansion 1.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): ResBlock(\n",
       "        (convs): ConvBlockBasic(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xresnet34()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = xresnet50()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 32, 32, 64]\n",
       "  (conv0): ConvLayer(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv1): ConvLayer(\n",
       "    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv2): ConvLayer(\n",
       "    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Body(\n",
       "  (layer_0): BasicLayer(\n",
       "    from 64 to 64, 3 blocks, expansion 4.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_2): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_1): BasicLayer(\n",
       "    from 256 to 128, 4 blocks, expansion 4.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_2): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_3): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_2): BasicLayer(\n",
       "    from 512 to 256, 6 blocks, expansion 4.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_2): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_3): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_4): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_5): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_3): BasicLayer(\n",
       "    from 1024 to 512, 3 blocks, expansion 4.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_2): ResBlock(\n",
       "      (convs): ConvBlockBottle(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_2): ConvLayer(\n",
       "          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Head(\n",
       "  (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "  (flat): Flatten()\n",
       "  (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 1000])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xb = torch.randn(8, 3, 128, 128)\n",
    "y = model(xb)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Alternative"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Xresnet34():\n",
    "    def __init__(self, stem_sizes=[32,32], block=ResBlock, blocks=[3, 4, 6, 3], \n",
    "                 conv_block=ConvBlockBasic, expansion=1, **kwargs):\n",
    "        \n",
    "        self.stem_sizes = stem_sizes\n",
    "        self.block=block\n",
    "        self.blocks=blocks\n",
    "        self.conv_block=conv_block\n",
    "        self.expansion=expansion\n",
    "        self.kwargs=kwargs\n",
    "    \n",
    "    def __call__(self, **kwargs):\n",
    "        return Net(stem_sizes=self.stem_sizes,block=self.block, blocks=self.blocks, \n",
    "               conv_block=self.conv_block, expansion=self.expansion, **self.kwargs, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Xresnet50(Xresnet34):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv_block=ConvBlockBottle\n",
    "        self.expansion=4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xresnet34 = Xresnet34()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Body(\n",
       "  (layer_0): BasicLayer(\n",
       "    from 64 to 64, 3 blocks, expansion 1.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_2): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_1): BasicLayer(\n",
       "    from 64 to 128, 4 blocks, expansion 1.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_2): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_3): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_2): BasicLayer(\n",
       "    from 128 to 256, 6 blocks, expansion 1.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_2): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_3): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_4): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_5): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_3): BasicLayer(\n",
       "    from 256 to 512, 3 blocks, expansion 1.\n",
       "    (block_0): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): DownsampleLayer(\n",
       "        (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "        (idconv): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_2): ResBlock(\n",
       "      (convs): ConvBlockBasic(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (identity): Noop()\n",
       "      (merge): Noop()\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xresnet34().body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xresnet50 = Xresnet50()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 32, 32, 64]\n",
       "    (conv0): ConvLayer(\n",
       "      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv1): ConvLayer(\n",
       "      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (conv2): ConvLayer(\n",
       "      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 3 blocks, expansion 4.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 256 to 128, 4 blocks, expansion 4.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 512 to 256, 6 blocks, expansion 4.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_3): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_4): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_5): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 1024 to 512, 3 blocks, expansion 4.\n",
       "      (block_0): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): DownsampleLayer(\n",
       "          (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "          (idconv): ConvLayer(\n",
       "            (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_2): ResBlock(\n",
       "        (convs): ConvBlockBottle(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_2): ConvLayer(\n",
       "            (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (identity): Noop()\n",
       "        (merge): Noop()\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xresnet50()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[32, 48]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xresnet50.stem_sizes = [32,48]\n",
    "xresnet50.stem_sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "type object got multiple values for keyword argument 'expansion'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-15-0d1032376ca1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mxresnet50\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpansion\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-11-90889ac6980a>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m         return Net(stem_sizes=self.stem_sizes,block=self.block, blocks=self.blocks, \n\u001b[0;32m---> 14\u001b[0;31m                conv_block=self.conv_block, expansion=self.expansion, **self.kwargs, **kwargs)\n\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m: type object got multiple values for keyword argument 'expansion'"
     ]
    }
   ],
   "source": [
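    "# Disabled: xresnet50(expansion=1) raises TypeError (multiple values for keyword argument 'expansion'),\n",
    "# because __call__ already passes expansion=self.expansion to Net alongside **kwargs.\n",
    "# xresnet50(expansion=1)"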
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['stem_sizes',\n",
       " 'block',\n",
       " 'blocks',\n",
       " 'conv_block',\n",
       " 'expansion',\n",
       " 'kwargs',\n",
       " '__module__',\n",
       " '__init__',\n",
       " '__doc__',\n",
       " '__call__',\n",
       " '__dict__',\n",
       " '__weakref__',\n",
       " '__repr__',\n",
       " '__hash__',\n",
       " '__str__',\n",
       " '__getattribute__',\n",
       " '__setattr__',\n",
       " '__delattr__',\n",
       " '__lt__',\n",
       " '__le__',\n",
       " '__eq__',\n",
       " '__ne__',\n",
       " '__gt__',\n",
       " '__ge__',\n",
       " '__new__',\n",
       " '__reduce_ex__',\n",
       " '__reduce__',\n",
       " '__subclasshook__',\n",
       " '__init_subclass__',\n",
       " '__format__',\n",
       " '__sizeof__',\n",
       " '__dir__',\n",
       " '__class__']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xresnet50.__dir__()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# end. xresnet.\n",
    "model_constructor\n",
    "by ayasyrev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_constructor.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_resnet.ipynb.\n",
      "Converted 03_xresnet.ipynb.\n",
      "Converted 80_test_layers.ipynb.\n",
      "Converted index.ipynb.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "from nbdev.export import *\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
